package m189tom190

import (
	"context"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/pkg/errors"
	"github.com/stackrox/rox/generated/storage"
	newSchema "github.com/stackrox/rox/migrator/migrations/m_189_to_m_190_vulnerability_requests_add_name/schema/new"
	"github.com/stackrox/rox/migrator/types"
	"github.com/stackrox/rox/pkg/logging"
	"github.com/stackrox/rox/pkg/postgres/pgutils"
	"github.com/stackrox/rox/pkg/protoconv"
	"github.com/stackrox/rox/pkg/sac"
	"github.com/stackrox/rox/pkg/stringutils"
	"github.com/stackrox/rox/pkg/utils"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// Request-naming constants, copied from central/vulnerabilityrequest/common/types.go.

// vulnReqNameSeparator joins the short name, the creation date and the sequence
// number of a generated vulnerability request name.
const vulnReqNameSeparator = "-"

// defaultUserShortName is the short-name prefix used when none can be derived
// from the requestor.
const defaultUserShortName = "SYS"

// batchSize is the maximum number of renamed vulnerability requests upserted in a
// single statement. It is a var rather than a const, presumably so tests can
// override it — TODO confirm.
var batchSize = 2000

// log is the module logger for this migration.
var log = logging.LoggerForModule()

// migrate creates the vulnerability requests table from the new schema model and
// then gives every request that has no name a generated one. Sequence numbers
// already used by named requests are collected first so that new names continue
// from them.
func migrate(database *types.Databases) error {
	ctx := sac.WithAllAccess(context.Background())
	gormDB := database.GormDB

	pgutils.CreateTableFromModel(ctx, gormDB, newSchema.CreateTableVulnerabilityRequestsStmt)

	// Seed name generation with the sequence numbers already in use, if any.
	seqNums, err := collectExistingSeqNums(ctx, gormDB)
	if err != nil {
		return err
	}
	return setRequestNames(ctx, gormDB, seqNums)
}

// collectExistingSeqNums scans every vulnerability request. For each request that
// already has a name, it records the highest sequence number seen for the request's
// creation month.
//
// The returned map seeds name generation in setRequestNames so that newly assigned
// sequence numbers start above every number already in use for that month. It
// returns an error if a row cannot be read, converted, or if an existing name has
// no parsable sequence number.
func collectExistingSeqNums(ctx context.Context, database *gorm.DB) (map[time.Month]int, error) {
	db := database.WithContext(ctx).Table(newSchema.VulnerabilityRequestsTableName).Select("serialized")
	rows, err := db.Rows()
	if err != nil {
		return nil, errors.Wrapf(err, "failed to iterate table %s", newSchema.VulnerabilityRequestsTableName)
	}
	defer func() { _ = rows.Close() }()

	// We do not care whether the months (1, 2, 3, ...) come from the same year or from
	// different years; partitioning by month only keeps the sequence numbers small.
	lastKnownSeqNumInfo := make(map[time.Month]int)
	for rows.Next() {
		var obj newSchema.VulnerabilityRequests
		if err = db.ScanRows(rows, &obj); err != nil {
			return nil, errors.Wrap(err, "failed to scan rows")
		}
		proto, err := newSchema.ConvertVulnerabilityRequestToProto(&obj)
		if err != nil {
			return nil, errors.Wrapf(err, "failed to convert %+v to proto", obj)
		}

		if proto.GetName() == "" {
			continue
		}
		month, seqNum, err := getMonthSeqNumPair(proto)
		if err != nil {
			return nil, err
		}
		// The query has no ORDER BY, so rows arrive in no particular order. Keep the
		// maximum rather than the last value seen. Otherwise a later, smaller number
		// could hide a larger one, and generated names would reuse existing numbers.
		if seqNum > lastKnownSeqNumInfo[month] {
			lastKnownSeqNumInfo[month] = seqNum
		}
	}
	if err := rows.Err(); err != nil {
		utils.Should(err)
		return nil, errors.Wrapf(err, "failed to get rows for %s", newSchema.VulnerabilityRequestsTableName)
	}
	return lastKnownSeqNumInfo, nil
}

// setRequestNames gives a generated name to every vulnerability request that has
// none and upserts the renamed rows in batches of at most batchSize.
//
// lastKnownSeqNumInfo maps a creation month to the last sequence number used for
// that month. It is advanced in place as names are generated. Requests that already
// have a name are left untouched.
func setRequestNames(ctx context.Context, database *gorm.DB, lastKnownSeqNumInfo map[time.Month]int) error {
	query := database.WithContext(ctx).Table(newSchema.VulnerabilityRequestsTableName).Select("serialized")

	rows, err := query.Rows()
	if err != nil {
		return errors.Wrapf(err, "failed to iterate table %s", newSchema.VulnerabilityRequestsTableName)
	}
	defer func() { _ = rows.Close() }()

	var updatedVulnReqs []*newSchema.VulnerabilityRequests
	var count int

	// We are not assigning seq num suffix in the order in which requests were created.
	for rows.Next() {
		var obj newSchema.VulnerabilityRequests
		if err = query.ScanRows(rows, &obj); err != nil {
			return errors.Wrap(err, "failed to scan rows")
		}
		proto, err := newSchema.ConvertVulnerabilityRequestToProto(&obj)
		if err != nil {
			return errors.Wrapf(err, "failed to convert %+v to proto", obj)
		}
		if proto.GetName() != "" {
			continue
		}

		setRequestName(proto, lastKnownSeqNumInfo)
		converted, err := newSchema.ConvertVulnerabilityRequestFromProto(proto)
		if err != nil {
			return errors.Wrapf(err, "failed to convert from proto %+v", proto)
		}
		updatedVulnReqs = append(updatedVulnReqs, converted)
		count++
		if len(updatedVulnReqs) == batchSize {
			if err = upsertVulnRequestsBatch(ctx, database, updatedVulnReqs); err != nil {
				return errors.Wrapf(err, "failed to upsert converted %d objects after %d upserted", len(updatedVulnReqs), count-len(updatedVulnReqs))
			}
			// Reuse the backing array for the next batch.
			updatedVulnReqs = updatedVulnReqs[:0]
		}
	}
	if err := rows.Err(); err != nil {
		utils.Should(err)
		return errors.Wrapf(err, "failed to get rows for %s", newSchema.VulnerabilityRequestsTableName)
	}
	if len(updatedVulnReqs) > 0 {
		if err := upsertVulnRequestsBatch(ctx, database, updatedVulnReqs); err != nil {
			return errors.Wrapf(err, "failed to upsert last %d objects", len(updatedVulnReqs))
		}
	}
	log.Infof("Populated name for %d vulnerability requests", count)
	return nil
}

// upsertVulnRequestsBatch inserts reqs into the vulnerability requests table. On
// conflict it updates all columns of the existing row (clause.OnConflict{UpdateAll}).
// It builds a fresh statement chain from database on every call, because GORM
// statements must not be reused after a finisher such as Create.
func upsertVulnRequestsBatch(ctx context.Context, database *gorm.DB, reqs []*newSchema.VulnerabilityRequests) error {
	return database.WithContext(ctx).
		Table(newSchema.VulnerabilityRequestsTableName).
		Clauses(clause.OnConflict{UpdateAll: true}).
		Model(newSchema.CreateTableVulnerabilityRequestsStmt.GormModel).
		Create(&reqs).Error
}

// getMonthSeqNumPair parses the sequence number of req's name and pairs it with the
// UTC month of req's creation time. The sequence number is the text after the last
// vulnReqNameSeparator.
//
// The creation time comes from protoconv.ConvertTimestampToTimeOrNow. Judging by its
// name, it substitutes the current time when CreatedAt is unset — TODO confirm.
//
// It returns an error if the name is empty, contains no separator, or its suffix is
// not an integer. The Atoi error is wrapped so callers can inspect the cause.
func getMonthSeqNumPair(req *storage.VulnerabilityRequest) (time.Month, int, error) {
	name := req.GetName()
	if name == "" {
		return time.January, 0, errors.Errorf("cannot determine sequence number because vulnerability request name for %s is empty", req.GetId())
	}
	idx := strings.LastIndex(name, vulnReqNameSeparator)
	if idx == -1 {
		return time.January, 0, errors.Errorf("sequence number not found in vulnerability request name %s", name)
	}
	seqNum, err := strconv.Atoi(name[idx+len(vulnReqNameSeparator):])
	if err != nil {
		return time.January, 0, errors.Wrap(err, "could not determine the vulnerability request sequence number")
	}
	return protoconv.ConvertTimestampToTimeOrNow(req.GetCreatedAt()).UTC().Month(), seqNum, nil
}

// setRequestName stores a freshly generated name (see requestName) in req.Name. This
// advances the counter in lastKnownSeqNumInfo for the request's creation month.
func setRequestName(req *storage.VulnerabilityRequest, lastKnownSeqNumInfo map[time.Month]int) {
	generated := requestName(req, lastKnownSeqNumInfo)
	req.Name = generated
}

// Following helper functions are copied from central/vulnerabilityrequest/manager/requestmgr/manager_impl.go.

// requestName builds a name of the form <short name>-<yymmdd>-<seq> for req, or
// returns "" when req is nil.
//
// <yymmdd> is req's UTC creation date. <seq> is one more than the value stored for
// that month in lastKnownSeqNumInfo, and the map is updated with the new value.
func requestName(req *storage.VulnerabilityRequest, lastKnownSeqNumInfo map[time.Month]int) string {
	if req == nil {
		return ""
	}

	createdAt := protoconv.ConvertTimestampToTimeOrNow(req.GetCreatedAt()).UTC()
	month := createdAt.Month()

	// A month not yet present reads as zero, so its first request gets 1.
	next := lastKnownSeqNumInfo[month] + 1
	lastKnownSeqNumInfo[month] = next

	return strings.Join([]string{
		getShortName(req.GetRequestor()),
		createdAt.Format("060102"),
		strconv.Itoa(next),
	}, vulnReqNameSeparator)
}

// getShortName derives a two-character short name for user from its upper-cased
// name. It uses the first character of the first and of the last non-empty
// space-separated word, as chosen by stringutils.FirstNonEmpty/LastNonEmpty.
// For example, "jane q. doe" yields "JD".
//
// It returns defaultUserShortName when user is nil or no non-empty word is found.
func getShortName(user *storage.SlimUser) string {
	if user == nil {
		return defaultUserShortName
	}

	name := strings.ToUpper(user.GetName())
	parts := strings.Split(name, " ")
	for i := range parts {
		parts[i] = strings.TrimSpace(parts[i])
	}

	firstName := stringutils.FirstNonEmpty(parts...)
	lastName := stringutils.LastNonEmpty(parts...)
	if firstName == "" || lastName == "" {
		return defaultUserShortName
	}
	// Take the first rune, not the first byte. Indexing a string yields a byte, and
	// for a multi-byte UTF-8 letter (e.g. 'É') %c would print an unrelated
	// character ('Ã'). Both names are non-empty here, so index 0 exists.
	return fmt.Sprintf("%c%c", []rune(firstName)[0], []rune(lastName)[0])
}
